dynamo_runtime/pipeline/network/tcp/
server.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use core::panic;
17use socket2::{Domain, SockAddr, Socket, Type};
18use std::{
19    collections::HashMap,
20    net::{SocketAddr, TcpListener},
21    os::fd::{AsFd, FromRawFd},
22    sync::Arc,
23};
24use tokio::sync::Mutex;
25
26use bytes::Bytes;
27use derive_builder::Builder;
28use futures::{SinkExt, StreamExt};
29use local_ip_address::{list_afinet_netifas, local_ip};
30use serde::{Deserialize, Serialize};
31use tokio::{
32    io::AsyncWriteExt,
33    sync::{mpsc, oneshot},
34    time,
35};
36use tokio_util::codec::{FramedRead, FramedWrite};
37
38use super::{
39    CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
40    StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
41};
42use crate::engine::AsyncEngineContext;
43use crate::pipeline::{
44    network::{
45        codec::{TwoPartMessage, TwoPartMessageType},
46        tcp::StreamType,
47        ResponseService, ResponseStreamPrologue,
48    },
49    PipelineError,
50};
51use crate::{error, ErrorContext, Result};
52
53#[allow(dead_code)]
54type ResponseType = TwoPartMessage;
55
56#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
57pub struct ServerOptions {
58    #[builder(default = "0")]
59    pub port: u16,
60
61    #[builder(default)]
62    pub interface: Option<String>,
63}
64
65impl ServerOptions {
66    pub fn builder() -> ServerOptionsBuilder {
67        ServerOptionsBuilder::default()
68    }
69}
70
71/// A [`TcpStreamServer`] is a TCP service that listens on a port for incoming response connections.
72/// A Response connection is a connection that is established by a client with the intention of sending
73/// specific data back to the server.
74pub struct TcpStreamServer {
75    local_ip: String,
76    local_port: u16,
77    state: Arc<Mutex<State>>,
78}
79
80// pub struct TcpStreamReceiver {
81//     address: TcpStreamConnectionInfo,
82//     state: Arc<Mutex<State>>,
83//     rx: mpsc::Receiver<ResponseType>,
84// }
85
86#[allow(dead_code)]
87struct RequestedSendConnection {
88    context: Arc<dyn AsyncEngineContext>,
89    connection: oneshot::Sender<Result<StreamSender, String>>,
90}
91
92struct RequestedRecvConnection {
93    context: Arc<dyn AsyncEngineContext>,
94    connection: oneshot::Sender<Result<StreamReceiver, String>>,
95}
96
97// /// When registering a new TcpStream on the server, the registration method will return a [`Connections`] object.
98// /// This [`Connections`] object will have two [`oneshot::Receiver`] objects, one for the [`TcpStreamSender`] and one for the [`TcpStreamReceiver`].
99// /// The [`Connections`] object can be awaited to get the [`TcpStreamSender`] and [`TcpStreamReceiver`] objects; these objects will
100// /// be made available when the matching Client has connected to the server.
101// pub struct Connections {
102//     pub address: TcpStreamConnectionInfo,
103
104//     /// The [`oneshot::Receiver`] for the [`TcpStreamSender`]. Awaiting this object will return the [`TcpStreamSender`] object once
105//     /// the client has connected to the server.
106//     pub sender: Option<oneshot::Receiver<StreamSender>>,
107
108//     /// The [`oneshot::Receiver`] for the [`TcpStreamReceiver`]. Awaiting this object will return the [`TcpStreamReceiver`] object once
109//     /// the client has connected to the server.
110//     pub receiver: Option<oneshot::Receiver<StreamReceiver>>,
111// }
112
113#[derive(Default)]
114struct State {
115    tx_subjects: HashMap<String, RequestedSendConnection>,
116    rx_subjects: HashMap<String, RequestedRecvConnection>,
117    handle: Option<tokio::task::JoinHandle<Result<()>>>,
118}
119
120impl TcpStreamServer {
121    pub fn options_builder() -> ServerOptionsBuilder {
122        ServerOptionsBuilder::default()
123    }
124
125    pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
126        let local_ip = match options.interface {
127            Some(interface) => {
128                let interfaces: HashMap<String, std::net::IpAddr> =
129                    list_afinet_netifas()?.into_iter().collect();
130
131                interfaces
132                    .get(&interface)
133                    .ok_or(PipelineError::Generic(format!(
134                        "Interface not found: {}",
135                        interface
136                    )))?
137                    .to_string()
138            }
139            None => local_ip().unwrap().to_string(),
140        };
141
142        let state = Arc::new(Mutex::new(State::default()));
143
144        let local_port = Self::start(local_ip.clone(), options.port, state.clone())
145            .await
146            .map_err(|e| {
147                PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
148            })?;
149
150        tracing::info!("tcp transport service on {}:{}", local_ip, local_port);
151
152        Ok(Arc::new(Self {
153            local_ip,
154            local_port,
155            state,
156        }))
157    }
158
159    #[allow(clippy::await_holding_lock)]
160    async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
161        let addr = format!("{}:{}", local_ip, local_port);
162        let state_clone = state.clone();
163        let mut guard = state.lock().await;
164        if guard.handle.is_some() {
165            panic!("TcpStreamServer already started");
166        }
167        let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
168        let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
169        guard.handle = Some(handle);
170        drop(guard);
171        let local_port = ready_rx.await??;
172        Ok(local_port)
173    }
174}
175
176// todo - possible rename ResponseService to ResponseServer
177#[async_trait::async_trait]
178impl ResponseService for TcpStreamServer {
179    /// Register a new subject and sender with the response subscriber
180    /// Produces an RAII object that will deregister the subject when dropped
181    ///
182    /// we need to register both data in and data out entries
183    /// there might be forward pipeline that want to consume the data out stream
184    /// and there might be a response stream that wants to consume the data in stream
185    /// on registration, we need to specific if we want data-in, data-out or both
186    /// this will map to the type of service that is runniing, i.e. Single or Many In //
187    /// Single or Many Out
188    ///
189    /// todo(ryan) - return a connection object that can be awaited. when successfully connected,
190    /// can ask for the sender and receiver
191    ///
192    /// OR
193    ///
194    /// we make it into register sender and register receiver, both would return a connection object
195    /// and when a connection is established, we'd get the respective sender or receiver
196    ///
197    /// the registration probably needs to be done in one-go, so we should use a builder object for
198    /// requesting a receiver and optional sender
199    async fn register(&self, options: StreamOptions) -> PendingConnections {
200        // oneshot channels to pass back the sender and receiver objects
201
202        let address = format!("{}:{}", self.local_ip, self.local_port);
203        tracing::debug!("Registering new TcpStream on {}", address);
204
205        let send_stream = if options.enable_request_stream {
206            let sender_subject = uuid::Uuid::new_v4().to_string();
207
208            let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
209
210            let connection_info = RequestedSendConnection {
211                context: options.context.clone(),
212                connection: pending_sender_tx,
213            };
214
215            let mut state = self.state.lock().await;
216            state
217                .tx_subjects
218                .insert(sender_subject.clone(), connection_info);
219
220            let registered_stream = RegisteredStream {
221                connection_info: TcpStreamConnectionInfo {
222                    address: address.clone(),
223                    subject: sender_subject.clone(),
224                    context: options.context.id().to_string(),
225                    stream_type: StreamType::Request,
226                }
227                .into(),
228                stream_provider: pending_sender_rx,
229            };
230
231            Some(registered_stream)
232        } else {
233            None
234        };
235
236        let recv_stream = if options.enable_response_stream {
237            let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
238            let receiver_subject = uuid::Uuid::new_v4().to_string();
239
240            let connection_info = RequestedRecvConnection {
241                context: options.context.clone(),
242                connection: pending_recver_tx,
243            };
244
245            let mut state = self.state.lock().await;
246            state
247                .rx_subjects
248                .insert(receiver_subject.clone(), connection_info);
249
250            let registered_stream = RegisteredStream {
251                connection_info: TcpStreamConnectionInfo {
252                    address: address.clone(),
253                    subject: receiver_subject.clone(),
254                    context: options.context.id().to_string(),
255                    stream_type: StreamType::Response,
256                }
257                .into(),
258                stream_provider: pending_recver_rx,
259            };
260
261            Some(registered_stream)
262        } else {
263            None
264        };
265
266        PendingConnections {
267            send_stream,
268            recv_stream,
269        }
270    }
271}
272
273// this method listens on a tcp port for incoming connections
274// new connections are expected to send a protocol specific handshake
275// for us to determine the subject they are interested in, in this case,
276// we expect the first message to be [`FirstMessage`] from which we find
277// the sender, then we spawn a task to forward all bytes from the tcp stream
278// to the sender
279async fn tcp_listener(
280    addr: String,
281    state: Arc<Mutex<State>>,
282    read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
283) -> Result<()> {
284    let listener = tokio::net::TcpListener::bind(&addr)
285        .await
286        .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
287
288    let listener = match listener {
289        Ok(listener) => {
290            let addr = listener
291                .local_addr()
292                .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
293                .unwrap();
294
295            read_tx
296                .send(Ok(addr.port()))
297                .expect("Failed to send ready signal");
298
299            listener
300        }
301        Err(e) => {
302            read_tx.send(Err(e)).expect("Failed to send ready signal");
303            return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
304        }
305    };
306
307    loop {
308        // todo - add instrumentation
309        // todo - add counter for all accepted connections
310        // todo - add gauge for all inflight connections
311        // todo - add counter for incoming bytes
312        // todo - add counter for outgoing bytes
313        let (stream, _addr) = match listener.accept().await {
314            Ok((stream, _addr)) => (stream, _addr),
315            Err(e) => {
316                // the client should retry, so we don't need to abort
317                tracing::warn!("failed to accept tcp connection: {}", e);
318                eprintln!("failed to accept tcp connection: {}", e);
319                continue;
320            }
321        };
322
323        match stream.set_nodelay(true) {
324            Ok(_) => (),
325            Err(e) => {
326                tracing::warn!("failed to set tcp stream to nodelay: {}", e);
327            }
328        }
329
330        match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
331            Ok(_) => (),
332            Err(e) => {
333                tracing::warn!("failed to set tcp stream to linger: {}", e);
334            }
335        }
336
337        tokio::spawn(handle_connection(stream, state.clone()));
338    }
339
340    // #[instrument(level = "trace"), skip(state)]
341    // todo - clone before spawn and trace process_stream
342    async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
343        let result = process_stream(stream, state).await;
344        match result {
345            Ok(_) => tracing::trace!("successfully processed tcp connection"),
346            Err(e) => {
347                tracing::warn!("failed to handle tcp connection: {}", e);
348                #[cfg(debug_assertions)]
349                eprintln!("failed to handle tcp connection: {}", e);
350            }
351        }
352    }
353
354    /// This method is responsible for the internal tcp stream handshake
355    /// The handshake will specialize the stream as a request/sender or response/receiver stream
356    async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
357        // split the socket in to a reader and writer
358        let (read_half, write_half) = tokio::io::split(stream);
359
360        // attach the codec to the reader and writer to get framed readers and writers
361        let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
362        let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
363
364        // the internal tcp [`CallHomeHandshake`] connects the socket to the requester
365        // here we await this first message as a raw bytes two part message
366        let first_message = framed_reader
367            .next()
368            .await
369            .ok_or(error!("Connection closed without a ControlMessage"))??;
370
371        // we await on the raw bytes which should come in as a header only message
372        // todo - improve error handling - check for no data
373        let handshake: CallHomeHandshake = match first_message.header() {
374            Some(header) => serde_json::from_slice(header).map_err(|e| {
375                error!(
376                    "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
377                )
378            })?,
379            None => {
380                return Err(error!("Expected ControlMessage, got DataMessage"));
381            }
382        };
383
384        // branch here to handle sender stream or receiver stream
385        match handshake.stream_type {
386            StreamType::Request => process_request_stream().await,
387            StreamType::Response => {
388                process_response_stream(handshake.subject, state, framed_reader, framed_writer)
389                    .await
390            }
391        }
392    }
393
394    async fn process_request_stream() -> Result<()> {
395        Ok(())
396    }
397
398    async fn process_response_stream(
399        subject: String,
400        state: Arc<Mutex<State>>,
401        mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
402        writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
403    ) -> Result<()> {
404        let response_stream = state
405            .lock().await
406            .rx_subjects
407            .remove(&subject)
408            .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
409
410        // unwrap response_stream
411        let RequestedRecvConnection {
412            context,
413            connection,
414        } = response_stream;
415
416        // the [`Prologue`]
417        // there must be a second control message it indicate the other segment's generate method was successful
418        let prologue = reader
419            .next()
420            .await
421            .ok_or(error!("Connection closed without a ControlMessge"))??;
422
423        // deserialize prologue
424        let prologue = match prologue.into_message_type() {
425            TwoPartMessageType::HeaderOnly(header) => {
426                let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
427                    .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
428                prologue
429            }
430            _ => {
431                panic!("Expected HeaderOnly ControlMessage; internally logic error")
432            }
433        };
434
435        // await the control message of GTG or Error, if error, then connection.send(Err(String)), which should fail the
436        // generate call chain
437        //
438        // note: this second control message might be delayed, but the expensive part of setting up the connection
439        // is both complete and ready for data flow; awaiting here is not a performance hit or problem and it allows
440        // us to trace the initial setup time vs the time to prologue
441        if let Some(error) = &prologue.error {
442            let _ = connection.send(Err(error.clone()));
443            return Err(error!("Received error prologue: {}", error));
444        }
445
446        // we need to know the buffer size from the registration options; add this to the RequestRecvConnection object
447        let (response_tx, response_rx) = mpsc::channel(64);
448
449        if connection
450            .send(Ok(crate::pipeline::network::StreamReceiver {
451                rx: response_rx,
452            }))
453            .is_err()
454        {
455            return Err(error!("The requester of the stream has been dropped before the connection was established"));
456        }
457
458        let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
459
460        // sender task
461        // issues control messages to the sender and when finished shuts down the socket
462        // this should be the last task to finish and must
463        let send_task = tokio::spawn(network_send_handler(writer, control_rx));
464
465        // forward task
466        let recv_task = tokio::spawn(network_receive_handler(
467            reader,
468            response_tx,
469            control_tx,
470            context.clone(),
471        ));
472
473        // check the results of each of the tasks
474        let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
475
476        monitor_result?;
477        forward_result?;
478
479        Ok(())
480    }
481
482    async fn network_receive_handler(
483        mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
484        response_tx: mpsc::Sender<Bytes>,
485        control_tx: mpsc::Sender<ControlMessage>,
486        context: Arc<dyn AsyncEngineContext>,
487    ) {
488        // loop over reading the tcp stream and checking if the writer is closed
489        let mut can_stop = true;
490        loop {
491            tokio::select! {
492                biased;
493
494                _ = response_tx.closed() => {
495                    tracing::trace!("response channel closed before the client finished writing data");
496                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
497                    break;
498                }
499
500                _ = context.killed() => {
501                    tracing::trace!("context kill signal received; shutting down");
502                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
503                    break;
504                }
505
506                _ = context.stopped(), if can_stop => {
507                    tracing::trace!("context stop signal received; shutting down");
508                    can_stop = false;
509                    control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
510                }
511
512                msg = framed_reader.next() => {
513                    match msg {
514                        Some(Ok(msg)) => {
515                            let (header, data) = msg.into_parts();
516
517                            // received a control message
518                            if !header.is_empty() {
519                                match process_control_message(header) {
520                                    Ok(ControlAction::Continue) => {}
521                                    Ok(ControlAction::Shutdown) => {
522                                        assert!(data.is_empty(), "received sentinel message with data; this should never happen");
523                                        tracing::trace!("received sentinel message; shutting down");
524                                        break;
525                                    }
526                                    Err(e) => {
527                                        // TODO(#171) - address fatal errors
528                                        panic!("{:?}", e);
529                                    }
530                                }
531                            }
532
533                            if !data.is_empty() {
534                                if let Err(err) = response_tx.send(data).await {
535                                    tracing::debug!("forwarding body/data message to response channel failed: {}", err);
536                                    control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
537                                    break;
538                                };
539                            }
540                        }
541                        Some(Err(_)) => {
542                            // TODO(#171) - address fatal errors
543                            panic!("invalid message issued over socket; this should never happen");
544                        }
545                        None => {
546                            // this is allowed but we try to avoid it
547                            // the logic is that the client will tell us when its is done and the server
548                            // will close the connection naturally when the sentinel message is received
549                            // the client closing early represents a transport error outside the control of the
550                            // transport library
551                            tracing::trace!("tcp stream was closed by client");
552                            break;
553                        }
554                    }
555                }
556
557            }
558        }
559    }
560
561    async fn network_send_handler(
562        socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
563        control_rx: mpsc::Receiver<ControlMessage>,
564    ) {
565        let mut socket_tx = socket_tx;
566        let mut control_rx = control_rx;
567
568        while let Some(control_msg) = control_rx.recv().await {
569            assert_ne!(
570                control_msg,
571                ControlMessage::Sentinel,
572                "received sentinel message; this should never happen"
573            );
574            let bytes =
575                serde_json::to_vec(&control_msg).expect("failed to serialize control message");
576            let message = TwoPartMessage::from_header(bytes.into());
577            match socket_tx.send(message).await {
578                Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
579                Err(_) => {
580                    tracing::debug!("failed to send control message {control_msg:?} to sender")
581                }
582            }
583        }
584
585        let mut inner = socket_tx.into_inner();
586        if let Err(e) = inner.flush().await {
587            tracing::debug!("failed to flush socket: {}", e);
588        }
589        if let Err(e) = inner.shutdown().await {
590            tracing::debug!("failed to shutdown socket: {}", e);
591        }
592    }
593}
594
/// What the receive loop should do after handling a control message from the client.
enum ControlAction {
    /// Keep reading frames from the socket.
    Continue,
    /// The client is done writing; stop reading.
    Shutdown,
}
599
600fn process_control_message(message: Bytes) -> Result<ControlAction> {
601    match serde_json::from_slice::<ControlMessage>(&message)? {
602        ControlMessage::Sentinel => {
603            // the client issued a sentinel message
604            // it has finished writing data and is now awaiting the server to close the connection
605            tracing::trace!("sentinel received; shutting down");
606            Ok(ControlAction::Shutdown)
607        }
608        ControlMessage::Kill | ControlMessage::Stop => {
609            // TODO(#171) - address fatal errors
610            anyhow::bail!(
611                "fatal error - unexpected control message received - this should never happen"
612            );
613        }
614    }
615}