1use core::panic;
5use socket2::{Domain, SockAddr, Socket, Type};
6use std::{
7    collections::HashMap,
8    net::{SocketAddr, TcpListener},
9    os::fd::{AsFd, FromRawFd},
10    sync::Arc,
11};
12use tokio::sync::Mutex;
13
14use bytes::Bytes;
15use derive_builder::Builder;
16use futures::{SinkExt, StreamExt};
17use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
18use serde::{Deserialize, Serialize};
19use tokio::{
20    io::AsyncWriteExt,
21    sync::{mpsc, oneshot},
22    time,
23};
24use tokio_util::codec::{FramedRead, FramedWrite};
25
26use super::{
27    CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
28    StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
29};
30use crate::engine::AsyncEngineContext;
31use crate::pipeline::{
32    PipelineError,
33    network::{
34        ResponseService, ResponseStreamPrologue,
35        codec::{TwoPartMessage, TwoPartMessageType},
36        tcp::StreamType,
37    },
38};
39use crate::{ErrorContext, Result, error};
40
41#[allow(dead_code)]
42type ResponseType = TwoPartMessage;
43
44#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
45pub struct ServerOptions {
46    #[builder(default = "0")]
47    pub port: u16,
48
49    #[builder(default)]
50    pub interface: Option<String>,
51}
52
53impl ServerOptions {
54    pub fn builder() -> ServerOptionsBuilder {
55        ServerOptionsBuilder::default()
56    }
57}
58
59pub struct TcpStreamServer {
63    local_ip: String,
64    local_port: u16,
65    state: Arc<Mutex<State>>,
66}
67
68#[allow(dead_code)]
75struct RequestedSendConnection {
76    context: Arc<dyn AsyncEngineContext>,
77    connection: oneshot::Sender<Result<StreamSender, String>>,
78}
79
80struct RequestedRecvConnection {
81    context: Arc<dyn AsyncEngineContext>,
82    connection: oneshot::Sender<Result<StreamReceiver, String>>,
83}
84
85#[derive(Default)]
102struct State {
103    tx_subjects: HashMap<String, RequestedSendConnection>,
104    rx_subjects: HashMap<String, RequestedRecvConnection>,
105    handle: Option<tokio::task::JoinHandle<Result<()>>>,
106}
107
108impl TcpStreamServer {
109    pub fn options_builder() -> ServerOptionsBuilder {
110        ServerOptionsBuilder::default()
111    }
112
113    pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
114        let local_ip = match options.interface {
115            Some(interface) => {
116                let interfaces: HashMap<String, std::net::IpAddr> =
117                    list_afinet_netifas()?.into_iter().collect();
118
119                interfaces
120                    .get(&interface)
121                    .ok_or(PipelineError::Generic(format!(
122                        "Interface not found: {}",
123                        interface
124                    )))?
125                    .to_string()
126            }
127            None => local_ip()
128                .or_else(|err| match err {
129                    Error::LocalIpAddressNotFound => {
130                        local_ipv6()
132                    }
133                    _ => Err(err),
134                })
135                .unwrap()
136                .to_string(),
137        };
138
139        let state = Arc::new(Mutex::new(State::default()));
140
141        let local_port = Self::start(local_ip.clone(), options.port, state.clone())
142            .await
143            .map_err(|e| {
144                PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
145            })?;
146
147        tracing::debug!("tcp transport service on {local_ip}:{local_port}");
148
149        Ok(Arc::new(Self {
150            local_ip,
151            local_port,
152            state,
153        }))
154    }
155
156    #[allow(clippy::await_holding_lock)]
157    async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
158        let addr = format!("{}:{}", local_ip, local_port);
159        let state_clone = state.clone();
160        let mut guard = state.lock().await;
161        if guard.handle.is_some() {
162            panic!("TcpStreamServer already started");
163        }
164        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
165        let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
166        guard.handle = Some(handle);
167        drop(guard);
168        let local_port = ready_rx.await??;
169        Ok(local_port)
170    }
171}
172
173#[async_trait::async_trait]
175impl ResponseService for TcpStreamServer {
176    async fn register(&self, options: StreamOptions) -> PendingConnections {
197        let address = format!("{}:{}", self.local_ip, self.local_port);
200        tracing::debug!("Registering new TcpStream on {}", address);
201
202        let send_stream = if options.enable_request_stream {
203            let sender_subject = uuid::Uuid::new_v4().to_string();
204
205            let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
206
207            let connection_info = RequestedSendConnection {
208                context: options.context.clone(),
209                connection: pending_sender_tx,
210            };
211
212            let mut state = self.state.lock().await;
213            state
214                .tx_subjects
215                .insert(sender_subject.clone(), connection_info);
216
217            let registered_stream = RegisteredStream {
218                connection_info: TcpStreamConnectionInfo {
219                    address: address.clone(),
220                    subject: sender_subject.clone(),
221                    context: options.context.id().to_string(),
222                    stream_type: StreamType::Request,
223                }
224                .into(),
225                stream_provider: pending_sender_rx,
226            };
227
228            Some(registered_stream)
229        } else {
230            None
231        };
232
233        let recv_stream = if options.enable_response_stream {
234            let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
235            let receiver_subject = uuid::Uuid::new_v4().to_string();
236
237            let connection_info = RequestedRecvConnection {
238                context: options.context.clone(),
239                connection: pending_recver_tx,
240            };
241
242            let mut state = self.state.lock().await;
243            state
244                .rx_subjects
245                .insert(receiver_subject.clone(), connection_info);
246
247            let registered_stream = RegisteredStream {
248                connection_info: TcpStreamConnectionInfo {
249                    address: address.clone(),
250                    subject: receiver_subject.clone(),
251                    context: options.context.id().to_string(),
252                    stream_type: StreamType::Response,
253                }
254                .into(),
255                stream_provider: pending_recver_rx,
256            };
257
258            Some(registered_stream)
259        } else {
260            None
261        };
262
263        PendingConnections {
264            send_stream,
265            recv_stream,
266        }
267    }
268}
269
270async fn tcp_listener(
277    addr: String,
278    state: Arc<Mutex<State>>,
279    read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
280) -> Result<()> {
281    let listener = tokio::net::TcpListener::bind(&addr)
282        .await
283        .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
284
285    let listener = match listener {
286        Ok(listener) => {
287            let addr = listener
288                .local_addr()
289                .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
290                .unwrap();
291
292            read_tx
293                .send(Ok(addr.port()))
294                .expect("Failed to send ready signal");
295
296            listener
297        }
298        Err(e) => {
299            read_tx.send(Err(e)).expect("Failed to send ready signal");
300            return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
301        }
302    };
303
304    loop {
305        let (stream, _addr) = match listener.accept().await {
311            Ok((stream, _addr)) => (stream, _addr),
312            Err(e) => {
313                tracing::warn!("failed to accept tcp connection: {}", e);
315                eprintln!("failed to accept tcp connection: {}", e);
316                continue;
317            }
318        };
319
320        match stream.set_nodelay(true) {
321            Ok(_) => (),
322            Err(e) => {
323                tracing::warn!("failed to set tcp stream to nodelay: {}", e);
324            }
325        }
326
327        match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
328            Ok(_) => (),
329            Err(e) => {
330                tracing::warn!("failed to set tcp stream to linger: {}", e);
331            }
332        }
333
334        tokio::spawn(handle_connection(stream, state.clone()));
335    }
336
337    async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
340        let result = process_stream(stream, state).await;
341        match result {
342            Ok(_) => tracing::trace!("successfully processed tcp connection"),
343            Err(e) => {
344                tracing::warn!("failed to handle tcp connection: {}", e);
345                #[cfg(debug_assertions)]
346                eprintln!("failed to handle tcp connection: {}", e);
347            }
348        }
349    }
350
351    async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
354        let (read_half, write_half) = tokio::io::split(stream);
356
357        let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
359        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
360
361        let first_message = framed_reader
364            .next()
365            .await
366            .ok_or(error!("Connection closed without a ControlMessage"))??;
367
368        let handshake: CallHomeHandshake = match first_message.header() {
371            Some(header) => serde_json::from_slice(header).map_err(|e| {
372                error!(
373                    "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
374                )
375            })?,
376            None => {
377                return Err(error!("Expected ControlMessage, got DataMessage"));
378            }
379        };
380
381        match handshake.stream_type {
383            StreamType::Request => process_request_stream().await,
384            StreamType::Response => {
385                process_response_stream(handshake.subject, state, framed_reader, framed_writer)
386                    .await
387            }
388        }
389    }
390
391    async fn process_request_stream() -> Result<()> {
392        Ok(())
393    }
394
395    async fn process_response_stream(
396        subject: String,
397        state: Arc<Mutex<State>>,
398        mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
399        writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
400    ) -> Result<()> {
401        let response_stream = state
402            .lock().await
403            .rx_subjects
404            .remove(&subject)
405            .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
406
407        let RequestedRecvConnection {
409            context,
410            connection,
411        } = response_stream;
412
413        let prologue = reader
416            .next()
417            .await
418            .ok_or(error!("Connection closed without a ControlMessge"))??;
419
420        let prologue = match prologue.into_message_type() {
422            TwoPartMessageType::HeaderOnly(header) => {
423                let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
424                    .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
425                prologue
426            }
427            _ => {
428                panic!("Expected HeaderOnly ControlMessage; internally logic error")
429            }
430        };
431
432        if let Some(error) = &prologue.error {
439            let _ = connection.send(Err(error.clone()));
440            return Err(error!("Received error prologue: {}", error));
441        }
442
443        let (response_tx, response_rx) = mpsc::channel(64);
445
446        if connection
447            .send(Ok(crate::pipeline::network::StreamReceiver {
448                rx: response_rx,
449            }))
450            .is_err()
451        {
452            return Err(error!(
453                "The requester of the stream has been dropped before the connection was established"
454            ));
455        }
456
457        let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
458
459        let send_task = tokio::spawn(network_send_handler(writer, control_rx));
463
464        let recv_task = tokio::spawn(network_receive_handler(
466            reader,
467            response_tx,
468            control_tx,
469            context.clone(),
470        ));
471
472        let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
474
475        monitor_result?;
476        forward_result?;
477
478        Ok(())
479    }
480
481    async fn network_receive_handler(
482        mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
483        response_tx: mpsc::Sender<Bytes>,
484        control_tx: mpsc::Sender<ControlMessage>,
485        context: Arc<dyn AsyncEngineContext>,
486    ) {
487        let mut can_stop = true;
489        loop {
490            tokio::select! {
491                biased;
492
493                _ = response_tx.closed() => {
494                    tracing::trace!("response channel closed before the client finished writing data");
495                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
496                    break;
497                }
498
499                _ = context.killed() => {
500                    tracing::trace!("context kill signal received; shutting down");
501                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
502                    break;
503                }
504
505                _ = context.stopped(), if can_stop => {
506                    tracing::trace!("context stop signal received; shutting down");
507                    can_stop = false;
508                    control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
509                }
510
511                msg = framed_reader.next() => {
512                    match msg {
513                        Some(Ok(msg)) => {
514                            let (header, data) = msg.into_parts();
515
516                            if !header.is_empty() {
518                                match process_control_message(header) {
519                                    Ok(ControlAction::Continue) => {}
520                                    Ok(ControlAction::Shutdown) => {
521                                        assert!(data.is_empty(), "received sentinel message with data; this should never happen");
522                                        tracing::trace!("received sentinel message; shutting down");
523                                        break;
524                                    }
525                                    Err(e) => {
526                                        panic!("{:?}", e);
528                                    }
529                                }
530                            }
531
532                            if !data.is_empty()
533                                && let Err(err) = response_tx.send(data).await {
534                                    tracing::debug!("forwarding body/data message to response channel failed: {}", err);
535                                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
536                                    break;
537                                };
538                        }
539                        Some(Err(_)) => {
540                            panic!("invalid message issued over socket; this should never happen");
542                        }
543                        None => {
544                            tracing::trace!("tcp stream was closed by client");
550                            break;
551                        }
552                    }
553                }
554
555            }
556        }
557    }
558
559    async fn network_send_handler(
560        socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
561        control_rx: mpsc::Receiver<ControlMessage>,
562    ) {
563        let mut socket_tx = socket_tx;
564        let mut control_rx = control_rx;
565
566        while let Some(control_msg) = control_rx.recv().await {
567            assert_ne!(
568                control_msg,
569                ControlMessage::Sentinel,
570                "received sentinel message; this should never happen"
571            );
572            let bytes =
573                serde_json::to_vec(&control_msg).expect("failed to serialize control message");
574            let message = TwoPartMessage::from_header(bytes.into());
575            match socket_tx.send(message).await {
576                Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
577                Err(_) => {
578                    tracing::debug!("failed to send control message {control_msg:?} to sender")
579                }
580            }
581        }
582
583        let mut inner = socket_tx.into_inner();
584        if let Err(e) = inner.flush().await {
585            tracing::debug!("failed to flush socket: {}", e);
586        }
587        if let Err(e) = inner.shutdown().await {
588            tracing::debug!("failed to shutdown socket: {}", e);
589        }
590    }
591}
592
/// What the receive loop should do after handling a control message.
enum ControlAction {
    /// Keep reading frames.
    Continue,
    /// Stop reading; the peer has signalled the end of the stream.
    Shutdown,
}
597
598fn process_control_message(message: Bytes) -> Result<ControlAction> {
599    match serde_json::from_slice::<ControlMessage>(&message)? {
600        ControlMessage::Sentinel => {
601            tracing::trace!("sentinel received; shutting down");
604            Ok(ControlAction::Shutdown)
605        }
606        ControlMessage::Kill | ControlMessage::Stop => {
607            anyhow::bail!(
609                "fatal error - unexpected control message received - this should never happen"
610            );
611        }
612    }
613}