dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9    io::AsyncWriteExt,
10    net::TcpStream,
11    time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
16use crate::engine::AsyncEngineContext;
17use crate::pipeline::network::{
18    ConnectionInfo, ResponseStreamPrologue, StreamSender,
19    codec::{TwoPartCodec, TwoPartMessage},
20    tcp::StreamType,
21};
22use crate::{ErrorContext, Result, error}; // Import SinkExt to use the `send` method
23
/// Client side of the TCP response-stream transport.
///
/// Response streams are opened through the associated function
/// `TcpClient::create_response_stream`, which does not take `self`; an instance only
/// carries the identifier of the worker it was created for.
// `worker_id` is assigned by the constructors but no code visible in this file reads it
// back, so the lint is silenced rather than dropping the field.
#[allow(dead_code)]
#[derive(Debug)]
pub struct TcpClient {
    /// Identifier of the worker owning this client (a random UUID string when built via
    /// `Default`).
    worker_id: String,
}
28
29impl Default for TcpClient {
30    fn default() -> Self {
31        TcpClient {
32            worker_id: uuid::Uuid::new_v4().to_string(),
33        }
34    }
35}
36
37impl TcpClient {
38    pub fn new(worker_id: String) -> Self {
39        TcpClient { worker_id }
40    }
41
42    async fn connect(address: &str) -> std::io::Result<TcpStream> {
43        // try to connect to the address; retry with linear backoff if AddrNotAvailable
44        let backoff = std::time::Duration::from_millis(200);
45        loop {
46            match TcpStream::connect(address).await {
47                Ok(socket) => {
48                    socket.set_nodelay(true)?;
49                    return Ok(socket);
50                }
51                Err(e) => {
52                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
53                        tracing::warn!("retry warning: failed to connect: {:?}", e);
54                        tokio::time::sleep(backoff).await;
55                    } else {
56                        return Err(e);
57                    }
58                }
59            }
60        }
61    }
62
63    pub async fn create_response_stream(
64        context: Arc<dyn AsyncEngineContext>,
65        info: ConnectionInfo,
66    ) -> Result<StreamSender> {
67        let info =
68            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
69        tracing::trace!("Creating response stream for {:?}", info);
70
71        if info.stream_type != StreamType::Response {
72            return Err(error!(
73                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
74                info.stream_type
75            ));
76        }
77
78        if info.context != context.id() {
79            return Err(error!(
80                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
81                context.id(),
82                info.context
83            ));
84        }
85
86        let stream = TcpClient::connect(&info.address).await?;
87        let (read_half, write_half) = tokio::io::split(stream);
88
89        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
90        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
91
92        // this is a oneshot channel that will be used to signal when the stream is closed
93        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
94        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
95        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
96        // captured by the monitor task
97        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
98
99        let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
100
101        // transport specific handshake message
102        let handshake = CallHomeHandshake {
103            subject: info.subject,
104            stream_type: StreamType::Response,
105        };
106
107        let handshake_bytes = match serde_json::to_vec(&handshake) {
108            Ok(hb) => hb,
109            Err(err) => {
110                return Err(error!(
111                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
112                ));
113            }
114        };
115        let msg = TwoPartMessage::from_header(handshake_bytes.into());
116
117        // issue the the first tcp handshake message
118        framed_writer
119            .send(msg)
120            .await
121            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
122
123        // set up the channel to send bytes to the transport layer
124        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
125
126        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
127
128        let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
129
130        tokio::spawn(async move {
131            // await both tasks
132            let (reader, writer) = tokio::join!(reader_task, writer_task);
133
134            match (reader, writer) {
135                (Ok(reader), Ok(writer)) => {
136                    let reader = reader.into_inner();
137
138                    let writer = match writer {
139                        Ok(writer) => writer.into_inner(),
140                        Err(e) => {
141                            tracing::error!("failed to join writer task: {:?}", e);
142                            return Err(e);
143                        }
144                    };
145
146                    let mut stream = reader.unsplit(writer);
147
148                    // await the tcp server to shutdown the socket connection
149                    // set a timeout for the server shutdown
150                    let mut buf = vec![0u8; 1024];
151                    let deadline = Instant::now() + Duration::from_secs(10);
152                    loop {
153                        let n = time::timeout_at(deadline, stream.read(&mut buf))
154                            .await
155                            .inspect_err(|_| {
156                                tracing::debug!("server did not close socket within the deadline");
157                            })?
158                            .inspect_err(|e| {
159                                tracing::debug!("failed to read from stream: {:?}", e);
160                            })?;
161                        if n == 0 {
162                            // Server has closed (FIN)
163                            break;
164                        }
165                    }
166
167                    Ok(())
168                }
169                _ => {
170                    tracing::error!("failed to join reader and writer tasks");
171                    anyhow::bail!("failed to join reader and writer tasks");
172                }
173            }
174        });
175
176        // set up the prologue for the stream
177        // this might have transport specific metadata in the future
178        let prologue = Some(ResponseStreamPrologue { error: None });
179
180        // create the stream sender
181        let stream_sender = StreamSender {
182            tx: bytes_tx,
183            prologue,
184        };
185
186        Ok(stream_sender)
187    }
188}
189
190async fn handle_reader(
191    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
192    context: Arc<dyn AsyncEngineContext>,
193    alive_tx: tokio::sync::oneshot::Sender<()>,
194) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
195    let mut framed_reader = framed_reader;
196    let mut alive_tx = alive_tx;
197    loop {
198        tokio::select! {
199            msg = framed_reader.next() => {
200                match msg {
201                    Some(Ok(two_part_msg)) => {
202                        match two_part_msg.optional_parts() {
203                           (Some(bytes), None) => {
204                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
205                                    Ok(msg) => msg,
206                                    Err(_) => {
207                                        // TODO(#171) - address fatal errors
208                                        panic!("fatal error - invalid control message detected");
209                                    }
210                                };
211
212                                match msg {
213                                    ControlMessage::Stop => {
214                                        context.stop();
215                                    }
216                                    ControlMessage::Kill => {
217                                        context.kill();
218                                    }
219                                    ControlMessage::Sentinel => {
220                                        // TODO(#171) - address fatal errors
221                                        panic!("received a sentinel message; this should never happen");
222                                    }
223                                }
224                           }
225                           _ => {
226                                panic!("received a non-control message; this should never happen");
227                           }
228                        }
229                    }
230                    Some(Err(_)) => {
231                        // TODO(#171) - address fatal errors
232                        // in this case the binary representation of the message is invalid
233                        panic!("fatal error - failed to decode message from stream; invalid line protocol");
234                    }
235                    None => {
236                        tracing::debug!("tcp stream closed by server");
237                        break;
238                    }
239                }
240            }
241            _ = alive_tx.closed() => {
242                break;
243            }
244        }
245    }
246    framed_reader
247}
248
249async fn handle_writer(
250    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
251    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
252    alive_rx: tokio::sync::oneshot::Receiver<()>,
253    context: Arc<dyn AsyncEngineContext>,
254) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
255    loop {
256        let msg = tokio::select! {
257            biased;
258
259            _ = context.killed() => {
260                tracing::trace!("context kill signal received; shutting down");
261                break;
262            }
263
264            _ = context.stopped() => {
265                tracing::trace!("context stop signal received; shutting down");
266                break;
267            }
268
269            msg = bytes_rx.recv() => {
270                match msg {
271                    Some(msg) => msg,
272                    None => {
273                        tracing::trace!("response channel closed; shutting down");
274                        break;
275                    }
276                }
277            }
278        };
279
280        if let Err(e) = framed_writer.send(msg).await {
281            tracing::trace!(
282                "failed to send message to network; possible disconnect: {:?}",
283                e
284            );
285            break;
286        }
287    }
288
289    // send sentinel message
290    let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
291    let msg = TwoPartMessage::from_header(message.into());
292    framed_writer.send(msg).await?;
293
294    drop(alive_rx);
295    Ok(framed_writer)
296}