dynamo_runtime/pipeline/network/tcp/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::sync::Arc;
17
18use futures::{SinkExt, StreamExt};
19use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
20use tokio::{
21    io::AsyncWriteExt,
22    net::TcpStream,
23    time::{self, Duration, Instant},
24};
25use tokio_util::codec::{FramedRead, FramedWrite};
26
27use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
28use crate::engine::AsyncEngineContext;
29use crate::pipeline::network::{
30    ConnectionInfo, ResponseStreamPrologue, StreamSender,
31    codec::{TwoPartCodec, TwoPartMessage},
32    tcp::StreamType,
33};
34use crate::{ErrorContext, Result, error}; // Import SinkExt to use the `send` method
35
/// Client side of the TCP response-stream transport.
///
/// Response streams are opened through the associated function
/// `TcpClient::create_response_stream`.
#[allow(dead_code)] // `worker_id` is stored but not read by any code in this file yet.
#[derive(Debug)]
pub struct TcpClient {
    /// Identifier of the worker that owns this client.
    worker_id: String,
}
40
41impl Default for TcpClient {
42    fn default() -> Self {
43        TcpClient {
44            worker_id: uuid::Uuid::new_v4().to_string(),
45        }
46    }
47}
48
49impl TcpClient {
50    pub fn new(worker_id: String) -> Self {
51        TcpClient { worker_id }
52    }
53
54    async fn connect(address: &str) -> std::io::Result<TcpStream> {
55        // try to connect to the address; retry with linear backoff if AddrNotAvailable
56        let backoff = std::time::Duration::from_millis(200);
57        loop {
58            match TcpStream::connect(address).await {
59                Ok(socket) => {
60                    socket.set_nodelay(true)?;
61                    return Ok(socket);
62                }
63                Err(e) => {
64                    if e.kind() == std::io::ErrorKind::AddrNotAvailable {
65                        tracing::warn!("retry warning: failed to connect: {:?}", e);
66                        tokio::time::sleep(backoff).await;
67                    } else {
68                        return Err(e);
69                    }
70                }
71            }
72        }
73    }
74
75    pub async fn create_response_stream(
76        context: Arc<dyn AsyncEngineContext>,
77        info: ConnectionInfo,
78    ) -> Result<StreamSender> {
79        let info =
80            TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
81        tracing::trace!("Creating response stream for {:?}", info);
82
83        if info.stream_type != StreamType::Response {
84            return Err(error!(
85                "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
86                info.stream_type
87            ));
88        }
89
90        if info.context != context.id() {
91            return Err(error!(
92                "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
93                context.id(),
94                info.context
95            ));
96        }
97
98        let stream = TcpClient::connect(&info.address).await?;
99        let (read_half, write_half) = tokio::io::split(stream);
100
101        let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
102        let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
103
104        // this is a oneshot channel that will be used to signal when the stream is closed
105        // when the stream sender is dropped, the bytes_rx will be closed and the forwarder task will exit
106        // the forwarder task will capture the alive_rx half of the oneshot channel; this will close the alive channel
107        // so the holder of the alive_tx half will be notified that the stream is closed; the alive_tx channel will be
108        // captured by the monitor task
109        let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
110
111        let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
112
113        // transport specific handshake message
114        let handshake = CallHomeHandshake {
115            subject: info.subject,
116            stream_type: StreamType::Response,
117        };
118
119        let handshake_bytes = match serde_json::to_vec(&handshake) {
120            Ok(hb) => hb,
121            Err(err) => {
122                return Err(error!(
123                    "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
124                ));
125            }
126        };
127        let msg = TwoPartMessage::from_header(handshake_bytes.into());
128
129        // issue the the first tcp handshake message
130        framed_writer
131            .send(msg)
132            .await
133            .map_err(|e| error!("failed to send handshake: {:?}", e))?;
134
135        // set up the channel to send bytes to the transport layer
136        let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
137
138        // forwards the bytes send from this stream to the transport layer; hold the alive_rx half of the oneshot channel
139
140        let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
141
142        tokio::spawn(async move {
143            // await both tasks
144            let (reader, writer) = tokio::join!(reader_task, writer_task);
145
146            match (reader, writer) {
147                (Ok(reader), Ok(writer)) => {
148                    let reader = reader.into_inner();
149
150                    let writer = match writer {
151                        Ok(writer) => writer.into_inner(),
152                        Err(e) => {
153                            tracing::error!("failed to join writer task: {:?}", e);
154                            return Err(e);
155                        }
156                    };
157
158                    let mut stream = reader.unsplit(writer);
159
160                    // await the tcp server to shutdown the socket connection
161                    // set a timeout for the server shutdown
162                    let mut buf = vec![0u8; 1024];
163                    let deadline = Instant::now() + Duration::from_secs(10);
164                    loop {
165                        let n = time::timeout_at(deadline, stream.read(&mut buf))
166                            .await
167                            .inspect_err(|_| {
168                                tracing::debug!("server did not close socket within the deadline");
169                            })?
170                            .inspect_err(|e| {
171                                tracing::debug!("failed to read from stream: {:?}", e);
172                            })?;
173                        if n == 0 {
174                            // Server has closed (FIN)
175                            break;
176                        }
177                    }
178
179                    Ok(())
180                }
181                _ => {
182                    tracing::error!("failed to join reader and writer tasks");
183                    anyhow::bail!("failed to join reader and writer tasks");
184                }
185            }
186        });
187
188        // set up the prologue for the stream
189        // this might have transport specific metadata in the future
190        let prologue = Some(ResponseStreamPrologue { error: None });
191
192        // create the stream sender
193        let stream_sender = StreamSender {
194            tx: bytes_tx,
195            prologue,
196        };
197
198        Ok(stream_sender)
199    }
200}
201
202async fn handle_reader(
203    framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
204    context: Arc<dyn AsyncEngineContext>,
205    alive_tx: tokio::sync::oneshot::Sender<()>,
206) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
207    let mut framed_reader = framed_reader;
208    let mut alive_tx = alive_tx;
209    loop {
210        tokio::select! {
211            msg = framed_reader.next() => {
212                match msg {
213                    Some(Ok(two_part_msg)) => {
214                        match two_part_msg.optional_parts() {
215                           (Some(bytes), None) => {
216                                let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
217                                    Ok(msg) => msg,
218                                    Err(_) => {
219                                        // TODO(#171) - address fatal errors
220                                        panic!("fatal error - invalid control message detected");
221                                    }
222                                };
223
224                                match msg {
225                                    ControlMessage::Stop => {
226                                        context.stop();
227                                    }
228                                    ControlMessage::Kill => {
229                                        context.kill();
230                                    }
231                                    ControlMessage::Sentinel => {
232                                        // TODO(#171) - address fatal errors
233                                        panic!("received a sentinel message; this should never happen");
234                                    }
235                                }
236                           }
237                           _ => {
238                                panic!("received a non-control message; this should never happen");
239                           }
240                        }
241                    }
242                    Some(Err(_)) => {
243                        // TODO(#171) - address fatal errors
244                        // in this case the binary representation of the message is invalid
245                        panic!("fatal error - failed to decode message from stream; invalid line protocol");
246                    }
247                    None => {
248                        tracing::debug!("tcp stream closed by server");
249                        break;
250                    }
251                }
252            }
253            _ = alive_tx.closed() => {
254                break;
255            }
256        }
257    }
258    framed_reader
259}
260
261async fn handle_writer(
262    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
263    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
264    alive_rx: tokio::sync::oneshot::Receiver<()>,
265    context: Arc<dyn AsyncEngineContext>,
266) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
267    loop {
268        let msg = tokio::select! {
269            biased;
270
271            _ = context.killed() => {
272                tracing::trace!("context kill signal received; shutting down");
273                break;
274            }
275
276            msg = bytes_rx.recv() => {
277                match msg {
278                    Some(msg) => msg,
279                    None => {
280                        tracing::trace!("response channel closed; shutting down");
281                        break;
282                    }
283                }
284            }
285        };
286
287        if let Err(e) = framed_writer.send(msg).await {
288            tracing::trace!(
289                "failed to send message to network; possible disconnect: {:?}",
290                e
291            );
292            break;
293        }
294    }
295
296    // send sentinel message
297    let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
298    let msg = TwoPartMessage::from_header(message.into());
299    framed_writer.send(msg).await?;
300
301    drop(alive_rx);
302    Ok(framed_writer)
303}