dynamo_runtime/pipeline/network/tcp/
client.rs1use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9 io::AsyncWriteExt,
10 net::TcpStream,
11 time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
16use crate::engine::AsyncEngineContext;
17use crate::pipeline::network::{
18 ConnectionInfo, ResponseStreamPrologue, StreamSender,
19 codec::{TwoPartCodec, TwoPartMessage},
20 tcp::StreamType,
21};
22use anyhow::{Context, Result, anyhow as error}; #[allow(dead_code)]
/// Client side of the TCP response-stream transport.
///
/// The associated function [`TcpClient::create_response_stream`] does not take
/// `self`; an instance only carries the worker identifier below.
#[derive(Debug)]
pub struct TcpClient {
    /// Identifier of the worker that owns this client. [`Default`] fills it
    /// with a freshly generated UUID v4 string.
    worker_id: String,
}
28
29impl Default for TcpClient {
30 fn default() -> Self {
31 TcpClient {
32 worker_id: uuid::Uuid::new_v4().to_string(),
33 }
34 }
35}
36
37impl TcpClient {
38 pub fn new(worker_id: String) -> Self {
39 TcpClient { worker_id }
40 }
41
42 async fn connect(address: &str) -> std::io::Result<TcpStream> {
43 let backoff = std::time::Duration::from_millis(200);
45 loop {
46 match TcpStream::connect(address).await {
47 Ok(socket) => {
48 socket.set_nodelay(true)?;
49 return Ok(socket);
50 }
51 Err(e) => {
52 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
53 tracing::warn!("retry warning: failed to connect: {:?}", e);
54 tokio::time::sleep(backoff).await;
55 } else {
56 return Err(e);
57 }
58 }
59 }
60 }
61 }
62
63 pub async fn create_response_stream(
64 context: Arc<dyn AsyncEngineContext>,
65 info: ConnectionInfo,
66 ) -> Result<StreamSender> {
67 let info =
68 TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
69 tracing::trace!("Creating response stream for {:?}", info);
70
71 if info.stream_type != StreamType::Response {
72 return Err(error!(
73 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
74 info.stream_type
75 ));
76 }
77
78 if info.context != context.id() {
79 return Err(error!(
80 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
81 context.id(),
82 info.context
83 ));
84 }
85
86 let stream = TcpClient::connect(&info.address).await?;
87 let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
88 let (read_half, write_half) = tokio::io::split(stream);
89
90 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
91 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
92
93 let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
99
100 let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
101
102 let handshake = CallHomeHandshake {
104 subject: info.subject.clone(),
105 stream_type: StreamType::Response,
106 };
107
108 let handshake_bytes = match serde_json::to_vec(&handshake) {
109 Ok(hb) => hb,
110 Err(err) => {
111 return Err(error!(
112 "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
113 ));
114 }
115 };
116 let msg = TwoPartMessage::from_header(handshake_bytes.into());
117
118 framed_writer
120 .send(msg)
121 .await
122 .map_err(|e| error!("failed to send handshake: {:?}", e))?;
123
124 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
126
127 let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
130
131 let subject = info.subject.clone();
132 tokio::spawn(async move {
133 let (reader, writer) = tokio::join!(reader_task, writer_task);
135
136 match (reader, writer) {
137 (Ok(reader), Ok(writer)) => {
138 let reader = reader.into_inner();
139
140 let writer = match writer {
141 Ok(writer) => writer.into_inner(),
142 Err(e) => {
143 tracing::error!("failed to join writer task: {:?}", e);
144 return Err(e);
145 }
146 };
147
148 let mut stream = reader.unsplit(writer);
149
150 let mut buf = vec![0u8; 1024];
153 let deadline = Instant::now() + Duration::from_secs(10);
154 loop {
155 let n = time::timeout_at(deadline, stream.read(&mut buf))
156 .await
157 .inspect_err(|_| {
158 tracing::debug!("server did not close socket within the deadline");
159 })?
160 .inspect_err(|e| {
161 tracing::debug!("failed to read from stream: {:?}", e);
162 })?;
163 if n == 0 {
164 break;
166 }
167 }
168
169 Ok(())
170 }
171 (Err(reader_err), Ok(_)) => {
172 tracing::error!(
173 "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
174 );
175 anyhow::bail!(
176 "reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
177 );
178 }
179 (Ok(_), Err(writer_err)) => {
180 tracing::error!(
181 "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
182 );
183 anyhow::bail!(
184 "writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
185 );
186 }
187 (Err(reader_err), Err(writer_err)) => {
188 tracing::error!(
189 "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
190 );
191 anyhow::bail!(
192 "both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
193 );
194 }
195 }
196 });
197
198 let prologue = Some(ResponseStreamPrologue { error: None });
201
202 let stream_sender = StreamSender {
204 tx: bytes_tx,
205 prologue,
206 };
207
208 Ok(stream_sender)
209 }
210}
211
212async fn handle_reader(
213 framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
214 context: Arc<dyn AsyncEngineContext>,
215 alive_tx: tokio::sync::oneshot::Sender<()>,
216) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
217 let mut framed_reader = framed_reader;
218 let mut alive_tx = alive_tx;
219 loop {
220 tokio::select! {
221 msg = framed_reader.next() => {
222 match msg {
223 Some(Ok(two_part_msg)) => {
224 match two_part_msg.optional_parts() {
225 (Some(bytes), None) => {
226 let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
227 Ok(msg) => msg,
228 Err(_) => {
229 panic!("fatal error - invalid control message detected");
231 }
232 };
233
234 match msg {
235 ControlMessage::Stop => {
236 context.stop();
237 }
238 ControlMessage::Kill => {
239 context.kill();
240 }
241 ControlMessage::Sentinel => {
242 panic!("received a sentinel message; this should never happen");
244 }
245 }
246 }
247 _ => {
248 panic!("received a non-control message; this should never happen");
249 }
250 }
251 }
252 Some(Err(e)) => {
253 panic!("fatal error - failed to decode message from stream; invalid line protocol: {e:?}");
256 }
257 None => {
258 tracing::debug!("tcp stream closed by server");
259 break;
260 }
261 }
262 }
263 _ = alive_tx.closed() => {
264 break;
265 }
266 }
267 }
268 framed_reader
269}
270
271async fn handle_writer(
272 mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
273 mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
274 alive_rx: tokio::sync::oneshot::Receiver<()>,
275 context: Arc<dyn AsyncEngineContext>,
276) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
277 loop {
278 let msg = tokio::select! {
279 biased;
280
281 _ = context.killed() => {
282 tracing::trace!("context kill signal received; shutting down");
283 break;
284 }
285
286 _ = context.stopped() => {
287 tracing::trace!("context stop signal received; shutting down");
288 break;
289 }
290
291 msg = bytes_rx.recv() => {
292 match msg {
293 Some(msg) => msg,
294 None => {
295 tracing::trace!("response channel closed; shutting down");
296 break;
297 }
298 }
299 }
300 };
301
302 if let Err(e) = framed_writer.send(msg).await {
303 tracing::trace!(
304 "failed to send message to network; possible disconnect: {:?}",
305 e
306 );
307 break;
308 }
309 }
310
311 let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
313 let msg = TwoPartMessage::from_header(message.into());
314 framed_writer.send(msg).await?;
315
316 drop(alive_rx);
317 Ok(framed_writer)
318}