dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9    io::AsyncWriteExt,
10    net::TcpStream,
11    time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
16use crate::engine::AsyncEngineContext;
17use crate::pipeline::network::{
18    ConnectionInfo, ResponseStreamPrologue, StreamSender,
19    codec::{TwoPartCodec, TwoPartMessage},
20    tcp::StreamType,
21};
22use anyhow::{Context, Result, anyhow as error}; // Import SinkExt to use the `send` method
23
24#[allow(dead_code)]
25pub struct TcpClient {
26    worker_id: String,
27}
28
29impl Default for TcpClient {
30    fn default() -> Self {
31        TcpClient {
32            worker_id: uuid::Uuid::new_v4().to_string(),
33        }
34    }
35}
36
37impl TcpClient {
38    pub fn new(worker_id: String) -> Self {
39        TcpClient { worker_id }
40    }
41
42    async fn connect(address: &str) -> std::io::Result<TcpStream> {
43        // try to connect to the address; retry with linear backoff if AddrNotAvailable
44        let backoff = std::time::Duration::from_millis(200);
45        loop {
46            match TcpStream::connect(address).await {
47                Ok(socket) => {
48                    socket.set_nodelay(true)?;
49                    return Ok(socket);
50                }
51                Err(e) => {
52                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
53                        tracing::warn!("retry warning: failed to connect: {:?}", e);
54                        tokio::time::sleep(backoff).await;
55                    } else {
56                        return Err(e);
57                    }
58                }
59            }
60        }
61    }
62
63    pub async fn create_response_stream(
64        context: Arc<dyn AsyncEngineContext>,
65        info: ConnectionInfo,
66    ) -> Result<StreamSender> {
67        let info =
68            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
69        tracing::trace!("Creating response stream for {:?}", info);
70
71        if info.stream_type != StreamType::Response {
72            return Err(error!(
73                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
74                info.stream_type
75            ));
76        }
77
78        if info.context != context.id() {
79            return Err(error!(
80                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
81                context.id(),
82                info.context
83            ));
84        }
85
86        let stream = TcpClient::connect(&info.address).await?;
87        let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
88        let (read_half, write_half) = tokio::io::split(stream);
89
90        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
91        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
92
93        // this is a oneshot channel that will be used to signal when the stream is closed
94        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
95        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
96        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
97        // captured by the monitor task
98        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
99
100        let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
101
102        // transport specific handshake message
103        let handshake = CallHomeHandshake {
104            subject: info.subject.clone(),
105            stream_type: StreamType::Response,
106        };
107
108        let handshake_bytes = match serde_json::to_vec(&handshake) {
109            Ok(hb) => hb,
110            Err(err) => {
111                return Err(error!(
112                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
113                ));
114            }
115        };
116        let msg = TwoPartMessage::from_header(handshake_bytes.into());
117
118        // issue the the first tcp handshake message
119        framed_writer
120            .send(msg)
121            .await
122            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
123
124        // set up the channel to send bytes to the transport layer
125        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
126
127        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
128
129        let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
130
131        let subject = info.subject.clone();
132        tokio::spawn(async move {
133            // await both tasks
134            let (reader, writer) = tokio::join!(reader_task, writer_task);
135
136            match (reader, writer) {
137                (Ok(reader), Ok(writer)) => {
138                    let reader = reader.into_inner();
139
140                    let writer = match writer {
141                        Ok(writer) => writer.into_inner(),
142                        Err(e) => {
143                            tracing::error!("failed to join writer task: {:?}", e);
144                            return Err(e);
145                        }
146                    };
147
148                    let mut stream = reader.unsplit(writer);
149
150                    // await the tcp server to shutdown the socket connection
151                    // set a timeout for the server shutdown
152                    let mut buf = vec![0u8; 1024];
153                    let deadline = Instant::now() + Duration::from_secs(10);
154                    loop {
155                        let n = time::timeout_at(deadline, stream.read(&mut buf))
156                            .await
157                            .inspect_err(|_| {
158                                tracing::debug!("server did not close socket within the deadline");
159                            })?
160                            .inspect_err(|e| {
161                                tracing::debug!("failed to read from stream: {:?}", e);
162                            })?;
163                        if n == 0 {
164                            // Server has closed (FIN)
165                            break;
166                        }
167                    }
168
169                    Ok(())
170                }
171                (Err(reader_err), Ok(_)) => {
172                    tracing::error!(
173                        "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
174                    );
175                    anyhow::bail!(
176                        "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
177                    );
178                }
179                (Ok(_), Err(writer_err)) => {
180                    tracing::error!(
181                        "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
182                    );
183                    anyhow::bail!(
184                        "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
185                    );
186                }
187                (Err(reader_err), Err(writer_err)) => {
188                    tracing::error!(
189                        "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
190                    );
191                    anyhow::bail!(
192                        "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
193                    );
194                }
195            }
196        });
197
198        // set up the prologue for the stream
199        // this might have transport specific metadata in the future
200        let prologue = Some(ResponseStreamPrologue { error: None });
201
202        // create the stream sender
203        let stream_sender = StreamSender {
204            tx: bytes_tx,
205            prologue,
206        };
207
208        Ok(stream_sender)
209    }
210}
211
212async fn handle_reader(
213    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
214    context: Arc<dyn AsyncEngineContext>,
215    alive_tx: tokio::sync::oneshot::Sender<()>,
216) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
217    let mut framed_reader = framed_reader;
218    let mut alive_tx = alive_tx;
219    loop {
220        tokio::select! {
221            msg = framed_reader.next() => {
222                match msg {
223                    Some(Ok(two_part_msg)) => {
224                        match two_part_msg.optional_parts() {
225                           (Some(bytes), None) => {
226                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
227                                    Ok(msg) => msg,
228                                    Err(_) => {
229                                        // TODO(#171) - address fatal errors
230                                        panic!("fatal error - invalid control message detected");
231                                    }
232                                };
233
234                                match msg {
235                                    ControlMessage::Stop => {
236                                        context.stop();
237                                    }
238                                    ControlMessage::Kill => {
239                                        context.kill();
240                                    }
241                                    ControlMessage::Sentinel => {
242                                        // TODO(#171) - address fatal errors
243                                        panic!("received a sentinel message; this should never happen");
244                                    }
245                                }
246                           }
247                           _ => {
248                                panic!("received a non-control message; this should never happen");
249                           }
250                        }
251                    }
252                    Some(Err(e)) => {
253                        // TODO(#171) - address fatal errors
254                        // in this case the binary representation of the message is invalid
255                        panic!("fatal error - failed to decode message from stream; invalid line protocol: {e:?}");
256                    }
257                    None => {
258                        tracing::debug!("tcp stream closed by server");
259                        break;
260                    }
261                }
262            }
263            _ = alive_tx.closed() => {
264                break;
265            }
266        }
267    }
268    framed_reader
269}
270
271async fn handle_writer(
272    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
273    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
274    alive_rx: tokio::sync::oneshot::Receiver<()>,
275    context: Arc<dyn AsyncEngineContext>,
276) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
277    // Only send sentinel for normal channel closure
278    let mut send_sentinel = true;
279
280    loop {
281        let msg = tokio::select! {
282            biased;
283
284            _ = context.killed() => {
285                tracing::trace!("context kill signal received; shutting down");
286                send_sentinel = false;
287                break;
288            }
289
290            _ = context.stopped() => {
291                tracing::trace!("context stop signal received; shutting down");
292                send_sentinel = false;
293                break;
294            }
295
296            msg = bytes_rx.recv() => {
297                match msg {
298                    Some(msg) => msg,
299                    None => {
300                        tracing::trace!("response channel closed; shutting down");
301                        break;
302                    }
303                }
304            }
305        };
306
307        if let Err(e) = framed_writer.send(msg).await {
308            tracing::trace!(
309                "failed to send message to network; possible disconnect: {:?}",
310                e
311            );
312            send_sentinel = false;
313            break;
314        }
315    }
316
317    // Send sentinel only on normal closure
318    if send_sentinel {
319        let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
320        let msg = TwoPartMessage::from_header(message.into());
321        framed_writer.send(msg).await?;
322    }
323
324    drop(alive_rx);
325    Ok(framed_writer)
326}