dynamo_runtime/pipeline/network/tcp/
client.rs1use std::sync::Arc;
5
6use futures::{SinkExt, StreamExt};
7use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
8use tokio::{
9 io::AsyncWriteExt,
10 net::TcpStream,
11 time::{self, Duration, Instant},
12};
13use tokio_util::codec::{FramedRead, FramedWrite};
14
15use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
16use crate::engine::AsyncEngineContext;
17use crate::pipeline::network::{
18 ConnectionInfo, ResponseStreamPrologue, StreamSender,
19 codec::{TwoPartCodec, TwoPartMessage},
20 tcp::StreamType,
21};
22use crate::{ErrorContext, Result, error}; #[allow(dead_code)]
25pub struct TcpClient {
26 worker_id: String,
27}
28
29impl Default for TcpClient {
30 fn default() -> Self {
31 TcpClient {
32 worker_id: uuid::Uuid::new_v4().to_string(),
33 }
34 }
35}
36
37impl TcpClient {
38 pub fn new(worker_id: String) -> Self {
39 TcpClient { worker_id }
40 }
41
42 async fn connect(address: &str) -> std::io::Result<TcpStream> {
43 let backoff = std::time::Duration::from_millis(200);
45 loop {
46 match TcpStream::connect(address).await {
47 Ok(socket) => {
48 socket.set_nodelay(true)?;
49 return Ok(socket);
50 }
51 Err(e) => {
52 if e.kind() == std::io::ErrorKind::AddrNotAvailable {
53 tracing::warn!("retry warning: failed to connect: {:?}", e);
54 tokio::time::sleep(backoff).await;
55 } else {
56 return Err(e);
57 }
58 }
59 }
60 }
61 }
62
63 pub async fn create_response_stream(
64 context: Arc<dyn AsyncEngineContext>,
65 info: ConnectionInfo,
66 ) -> Result<StreamSender> {
67 let info =
68 TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
69 tracing::trace!("Creating response stream for {:?}", info);
70
71 if info.stream_type != StreamType::Response {
72 return Err(error!(
73 "Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
74 info.stream_type
75 ));
76 }
77
78 if info.context != context.id() {
79 return Err(error!(
80 "Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
81 context.id(),
82 info.context
83 ));
84 }
85
86 let stream = TcpClient::connect(&info.address).await?;
87 let (read_half, write_half) = tokio::io::split(stream);
88
89 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
90 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
91
92 let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
98
99 let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
100
101 let handshake = CallHomeHandshake {
103 subject: info.subject,
104 stream_type: StreamType::Response,
105 };
106
107 let handshake_bytes = match serde_json::to_vec(&handshake) {
108 Ok(hb) => hb,
109 Err(err) => {
110 return Err(error!(
111 "create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
112 ));
113 }
114 };
115 let msg = TwoPartMessage::from_header(handshake_bytes.into());
116
117 framed_writer
119 .send(msg)
120 .await
121 .map_err(|e| error!("failed to send handshake: {:?}", e))?;
122
123 let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
125
126 let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
129
130 tokio::spawn(async move {
131 let (reader, writer) = tokio::join!(reader_task, writer_task);
133
134 match (reader, writer) {
135 (Ok(reader), Ok(writer)) => {
136 let reader = reader.into_inner();
137
138 let writer = match writer {
139 Ok(writer) => writer.into_inner(),
140 Err(e) => {
141 tracing::error!("failed to join writer task: {:?}", e);
142 return Err(e);
143 }
144 };
145
146 let mut stream = reader.unsplit(writer);
147
148 let mut buf = vec![0u8; 1024];
151 let deadline = Instant::now() + Duration::from_secs(10);
152 loop {
153 let n = time::timeout_at(deadline, stream.read(&mut buf))
154 .await
155 .inspect_err(|_| {
156 tracing::debug!("server did not close socket within the deadline");
157 })?
158 .inspect_err(|e| {
159 tracing::debug!("failed to read from stream: {:?}", e);
160 })?;
161 if n == 0 {
162 break;
164 }
165 }
166
167 Ok(())
168 }
169 _ => {
170 tracing::error!("failed to join reader and writer tasks");
171 anyhow::bail!("failed to join reader and writer tasks");
172 }
173 }
174 });
175
176 let prologue = Some(ResponseStreamPrologue { error: None });
179
180 let stream_sender = StreamSender {
182 tx: bytes_tx,
183 prologue,
184 };
185
186 Ok(stream_sender)
187 }
188}
189
190async fn handle_reader(
191 framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
192 context: Arc<dyn AsyncEngineContext>,
193 alive_tx: tokio::sync::oneshot::Sender<()>,
194) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
195 let mut framed_reader = framed_reader;
196 let mut alive_tx = alive_tx;
197 loop {
198 tokio::select! {
199 msg = framed_reader.next() => {
200 match msg {
201 Some(Ok(two_part_msg)) => {
202 match two_part_msg.optional_parts() {
203 (Some(bytes), None) => {
204 let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
205 Ok(msg) => msg,
206 Err(_) => {
207 panic!("fatal error - invalid control message detected");
209 }
210 };
211
212 match msg {
213 ControlMessage::Stop => {
214 context.stop();
215 }
216 ControlMessage::Kill => {
217 context.kill();
218 }
219 ControlMessage::Sentinel => {
220 panic!("received a sentinel message; this should never happen");
222 }
223 }
224 }
225 _ => {
226 panic!("received a non-control message; this should never happen");
227 }
228 }
229 }
230 Some(Err(_)) => {
231 panic!("fatal error - failed to decode message from stream; invalid line protocol");
234 }
235 None => {
236 tracing::debug!("tcp stream closed by server");
237 break;
238 }
239 }
240 }
241 _ = alive_tx.closed() => {
242 break;
243 }
244 }
245 }
246 framed_reader
247}
248
249async fn handle_writer(
250 mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
251 mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
252 alive_rx: tokio::sync::oneshot::Receiver<()>,
253 context: Arc<dyn AsyncEngineContext>,
254) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
255 loop {
256 let msg = tokio::select! {
257 biased;
258
259 _ = context.killed() => {
260 tracing::trace!("context kill signal received; shutting down");
261 break;
262 }
263
264 _ = context.stopped() => {
265 tracing::trace!("context stop signal received; shutting down");
266 break;
267 }
268
269 msg = bytes_rx.recv() => {
270 match msg {
271 Some(msg) => msg,
272 None => {
273 tracing::trace!("response channel closed; shutting down");
274 break;
275 }
276 }
277 }
278 };
279
280 if let Err(e) = framed_writer.send(msg).await {
281 tracing::trace!(
282 "failed to send message to network; possible disconnect: {:?}",
283 e
284 );
285 break;
286 }
287 }
288
289 let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
291 let msg = TwoPartMessage::from_header(message.into());
292 framed_writer.send(msg).await?;
293
294 drop(alive_rx);
295 Ok(framed_writer)
296}